License:
The data used in this project was collected by the Orthopaedics Department of the University Hospital Schleswig-Holstein (UKSH) in Kiel, Germany. Approval for the use of this data was granted by the department director, Prof. Dr. Babak Moradi.
Portfolio Exam¶
Story¶
Welcome to Therapiecenter Weinstadt, where we are pioneering a new project at the intersection of technology and human movement analysis. Our goal is ambitious: can we determine if someone is male or female simply by analyzing their gait?
To answer this question, we are developing a Proof of Concept (POC) that leverages deep learning techniques. We begin by gathering data from approximately 90 individuals. Instead of traditional markers, we use a markerless camera system to track each subject’s movements with precision, capturing detailed angles for each joint over multiple gait cycles. Each subject walks on a treadmill, allowing us to record more than one gait cycle per person. This method enables us to build a dataset from a smaller group of participants while still capturing a broad range of motion data for each individual, offering diverse insights into their movement patterns.
By analyzing these angles and joint positions across the gait cycles, we aim to discover gender-related differences that may exist within the mechanics of human movement. This project is not only about classification but also about expanding our understanding of human biomechanics and how it can differ between individuals.
Starting with a POC helps us keep things manageable and allows us to refine our approach with a small, focused dataset. If successful, this POC could pave the way for more comprehensive studies in the future, allowing us to apply this knowledge in clinical, therapeutic, or even sports performance settings.
At Therapiecenter Weinstadt, we believe this project has the potential to open new doors in understanding and analyzing human gait in ways previously unexplored.
# import all necessary libraries
import copy
import os
from PIL import Image as PILImage  # aliased so it does not shadow IPython's Image
from IPython.display import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn import metrics
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
# Display all columns
pd.options.display.max_columns = None
pd.options.display.max_rows = None
Google Colab¶
# mount drive
# from google.colab import drive
# drive.mount('/content/drive')
# data_raw = pd.read_json(r'/content/drive/MyDrive/HochschuleKiel/DeepLearning/Portfolio Exam/data/data_filtered.json', orient='records')
Private Computer¶
data_raw = pd.read_json(r'../data/data_filtered.json', orient='records')
The Data¶
The dataset contains gait data collected from 89 unique subjects, each of whom walked on a treadmill in a controlled gait lab environment. During these sessions, each subject completed multiple gait cycles within approximately 30 seconds of walking. The aim of this data collection is to explore whether a deep learning model can distinguish between male and female subjects based solely on the features of an individual gait cycle. This project serves as a proof of concept for gender classification from just a single gait cycle. The initial data collection process is documented in DataCollection.ipynb, and the helper functions used in this project are stored in ScriptClasses.py.
Explanation of the Features:
age: Age of the subject
subject: Unique ID of the subject
sex: Gender of the subject
created: Date of the gait analysis
cycle: Which gait cycle
Followed by the left and right angles in the X, Y & Z axes for:
Ankle Angles: Reflects the range of motion in the ankle joint, measuring dorsiflexion and plantarflexion movements critical for balance and walking.
CenterOfMass_corr: Represents the corrected position of the body's center of mass, which is essential for understanding stability and movement coordination.
Elbow Angles: Measures the angle of flexion and extension at the elbow joint, indicating arm mobility and function.
Foot Pitch Angles: Refers to the angle of the foot relative to the ground, particularly during the stance phase of walking, to analyze gait efficiency.
Foot Progression: Describes the alignment of the foot in the direction of walking, showing toe-in or toe-out positioning in gait analysis.
Hip Angles: Captures the hip joint’s range of motion, including flexion, extension, and rotation, essential for evaluating lower limb mobility.
Knee Angles: Reflects the angle between the thigh and lower leg, providing insight into knee flexion and extension during gait and other movements.
Pelvic Angles: Measures the orientation of the pelvis, including tilt, obliquity, and rotation, important for assessing balance and posture.
Shoulder Angles: Reflects the movement of the shoulder joint, including abduction, flexion, and extension, relevant for upper body function.
Thorax Angles: Represents the movement of the thoracic spine and upper body posture, crucial for overall balance and body alignment.
Thorax_Lab Angles: A lab-specific measurement of the thoracic angles, potentially adjusted for precise posture and movement analysis.
Further Information about the Joint Angles:
The left and right joint angles are stored as lists of 101 values, representing the joint angles throughout a complete gait cycle, sampled from 0% to 100%.
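As a minimal sketch of this structure (using a real column name from the dataset but toy values, not actual measurements), each cell holds 101 samples covering 0–100% of a gait cycle, so the list index doubles as the cycle percentage:

```python
import numpy as np
import pandas as pd

# Toy stand-in for one row of the real data: 101 synthetic angle values
row = pd.Series({"Left Knee Angles_X": list(np.linspace(2.0, 60.0, 101))})

angles = np.asarray(row["Left Knee Angles_X"])
assert angles.shape == (101,)          # one value per percent of the gait cycle

percent = np.arange(101)               # 0% ... 100% of the gait cycle
peak_pct = int(percent[angles.argmax()])
print(f"Peak knee angle of {angles.max():.1f} deg at {peak_pct}% of the cycle")
```

This list-per-cell layout is why most of the DataFrame columns show up with dtype `object` rather than a numeric dtype.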
display(data_raw.head(3))
| age | subject | sex | created | cycle | Left Ankle Angles_X | Left Ankle Angles_Y | Left Ankle Angles_Z | Left CenterOfMass_corr_X | Left CenterOfMass_corr_Y | Left CenterOfMass_corr_Z | Left Elbow Angles_X | Left Elbow Angles_Y | Left Elbow Angles_Z | Left Foot Pitch Angles_X | Left Foot Pitch Angles_Y | Left Foot Pitch Angles_Z | Left Foot Progression_X | Left Foot Progression_Y | Left Foot Progression_Z | Left Hip Angles_X | Left Hip Angles_Y | Left Hip Angles_Z | Left Knee Angles_X | Left Knee Angles_Y | Left Knee Angles_Z | Left Pelvic Angles_X | Left Pelvic Angles_Y | Left Pelvic Angles_Z | Left Shoulder Angles_X | Left Shoulder Angles_Y | Left Shoulder Angles_Z | Left Thorax Angles_X | Left Thorax Angles_Y | Left Thorax Angles_Z | Left Thorax_Lab Angles_X | Left Thorax_Lab Angles_Y | Left Thorax_Lab Angles_Z | Right Ankle Angles_X | Right Ankle Angles_Y | Right Ankle Angles_Z | Right CenterOfMass_corr_X | Right CenterOfMass_corr_Y | Right CenterOfMass_corr_Z | Right Elbow Angles_X | Right Elbow Angles_Y | Right Elbow Angles_Z | Right Foot Pitch Angles_X | Right Foot Pitch Angles_Y | Right Foot Pitch Angles_Z | Right Foot Progression_X | Right Foot Progression_Y | Right Foot Progression_Z | Right Hip Angles_X | Right Hip Angles_Y | Right Hip Angles_Z | Right Knee Angles_X | Right Knee Angles_Y | Right Knee Angles_Z | Right Pelvic Angles_X | Right Pelvic Angles_Y | Right Pelvic Angles_Z | Right Shoulder Angles_X | Right Shoulder Angles_Y | Right Shoulder Angles_Z | Right Thorax Angles_X | Right Thorax Angles_Y | Right Thorax Angles_Z | Right Thorax_Lab Angles_X | Right Thorax_Lab Angles_Y | Right Thorax_Lab Angles_Z | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 19 | sub-002 | male | 2024-05-29 12:27:14 | cycle_1 | [-2.768442, -3.496685, -4.069201, -4.462231, -... | [7.935368, 8.292075, 8.685886, 9.081171, 9.438... | [0.6022689999999999, 0.629934, 0.660829, 0.667... | [0.0, -0.002226, -0.00464, -0.0072429999999999... | [0.0, -0.000268, -0.00040699999999999997, -0.0... | [0.0, -0.001321, -0.0026279999999999997, -0.00... | [22.490803, 22.188433, 21.868492, 21.554268, 2... | [-0.010251999999999999, -0.01488, -0.017207, -... | [104.20636, 103.466881, 102.595413, 101.67038,... | [14.214731, 13.85244, 13.169363, 12.26093, 11.... | [-3.179904, -3.617692, -4.20662, -4.875635, -5... | [19.032904, 18.2679, 17.341208, 16.317291, 15.... | [14.214735, 13.852443, 13.169367, 12.260934, 1... | [-3.179903, -3.61769, -4.206618, -4.875634, -5... | [-19.032904, -18.2679, -17.341208, -16.317291,... | [20.6766, 20.222738, 19.856031, 19.556093, 19.... | [-4.393382, -4.227548, -4.138162, -4.11265, -4... | [-0.280891, -1.543496, -2.484963, -3.042148, -... | [2.135388, 1.461672, 1.318257, 1.61889, 2.2538... | [-0.194929, -0.818533, -1.206393, -1.397462, -... | [-14.350356, -12.79824, -11.413175, -10.301519... | [3.991429, 3.891927, 3.795211, 3.698814, 3.599... | [-0.522435, -0.717429, -0.9033939999999999, -1... | [-4.647504, -4.119165, -3.555717, -2.977560999... | [2.708186, 2.7229010000000002, 2.779783, 2.872... | [-8.669715, -8.55989, -8.467485, -8.390489, -8... | [16.503153, 16.36441, 16.326563, 16.373936, 16... | [3.141546, 3.062634, 2.954683, 2.827985, 2.697... | [-1.945191, -1.821488, -1.716818, -1.627697, -... | [-2.919237, -3.102238, -3.3029479999999998, -3... | [7.254066, 7.08643, 6.894266, 6.684382, 6.4677... | [-2.086775, -2.146077, -2.215677, -2.286167, -... | [-7.7877279999999995, -7.421407, -7.040127, -6... | [-5.4825289999999995, -5.833911, -6.063342, -6... | [4.380441, 4.43795, 4.501745, 4.57004, 4.64473... | [2.074525, 3.183789, 4.164791, 4.954755, 5.505... | [0.0, 0.0015400000000000001, 0.003264, 0.00518... 
| [0.0, 0.000522, 0.001184, 0.001996, 0.002956, ... | [0.0, -0.000998, -0.0020039999999999997, -0.00... | [28.822729, 28.978001, 29.148277, 29.292614, 2... | [0.003681, 0.007298, 0.009037, 0.008681, 0.007... | [97.213318, 96.508972, 95.770767, 95.081993, 9... | [10.846501, 9.729046, 8.539912, 7.347647, 6.20... | [4.045348, 4.578834, 5.088349, 5.536511, 5.904... | [-13.617139, -12.353103, -10.997869, -9.585029... | [10.846505, 9.72905, 8.539916999999999, 7.3476... | [4.045347, 4.578834, 5.088348, 5.53651, 5.9045... | [-13.617139, -12.353103, -10.997869, -9.585029... | [18.893475, 19.141722, 19.27253, 19.289482, 19... | [-2.407445, -2.328364, -2.257316, -2.183507, -... | [-3.71332, -3.729039, -3.720179, -3.685369, -3... | [3.157154, 3.974303, 4.884859, 5.819755, 6.725... | [-0.8432679999999999, -0.736147, -0.5383439999... | [-15.543177, -15.934821, -16.170027, -16.20980... | [0.883767, 1.096725, 1.264113, 1.393558, 1.498... | [-0.365954, -0.43199099999999996, -0.456648999... | [4.044459, 4.620235, 5.210877, 5.8096049999999... | [-5.663974, -6.074879, -6.471653, -6.800329, -... | [-10.194664, -10.122283, -10.048591, -9.980521... | [8.034688, 7.914454, 7.848456, 7.869713, 7.992... | [7.059786, 6.799669, 6.570794, 6.367293, 6.177... | [-0.765127, -0.755408, -0.785045, -0.849924999... | [-4.957033, -5.034188, -5.127676, -5.235561, -... | [8.011764, 7.970358, 7.914809, 7.84646, 7.7663... | [-0.435287, -0.484906, -0.531291, -0.570133, -... | [-0.9707739999999999, -0.470057, 0.02395599999... |
| 1 | 19 | sub-002 | male | 2024-05-29 12:27:14 | cycle_10 | [-7.611665, -8.296784, -8.636337, -8.653679, -... | [6.345243, 6.41904, 6.53749, 6.682159, 6.82977... | [0.015241, 1.130675, 2.046575, 2.656725, 2.912... | [0.016838, 0.015413, 0.013836999999999999, 0.0... | [-0.009348, -0.009233, -0.008978, -0.008558, -... | [0.000483, -0.001058, -0.002572, -0.004029, -0... | [22.925438, 22.503048, 22.122334, 21.826393, 2... | [-0.019576, -0.016118999999999998, -0.01080699... | [102.04879, 101.1464, 100.409393, 99.857368, 9... | [12.341584, 12.335873, 12.083894, 11.615406, 1... | [-5.800112, -5.918781, -6.172589, -6.485641, -... | [26.423983, 25.275766, 23.834427, 22.164652, 2... | [12.341588, 12.335876, 12.083896, 11.615409, 1... | [-5.80011, -5.918779, -6.172588, -6.485639, -6... | [-26.423983, -25.275766, -23.834427, -22.16465... | [20.424934, 19.962803, 19.509947, 19.021643, 1... | [-1.797753, -1.688857, -1.708179, -1.814137999... | [-2.878228, -2.440343, -1.9427249999999998, -1... | [2.762847, 1.953323, 1.724262, 1.937817, 2.447... | [-0.236487, -0.9723259999999999, -1.39173, -1.... | [-18.446913, -19.233843, -19.68598, -19.714357... | [1.090251, 0.818932, 0.5202, 0.200094, -0.1280... | [-0.028227, -0.38070299999999996, -0.728997, -... | [-4.88499, -4.3810839999999995, -3.769862, -3.... | [3.134768, 3.105432, 3.044675, 2.940096, 2.789... | [-9.577947, -9.449483, -9.323687, -9.206041, -... | [18.54364, 18.746981, 18.860155, 18.887785, 18... | [4.173757, 4.379654, 4.612041, 4.860198, 5.105... | [-1.197732, -0.843793, -0.5067539999999999, -0... | [-2.178959, -2.3056609999999997, -2.478065, -2... | [5.30765, 5.244393, 5.181376, 5.113739, 5.0358... | [-1.020125, -1.011145, -1.010935, -1.014616, -... | [-7.164998, -6.754079, -6.28371, -5.74174, -5.... | [-5.628355, -5.81796, -5.811576, -5.627428, -5... | [3.403195, 3.262351, 3.242479, 3.355409, 3.601... | [-3.4450979999999998, -2.903352, -2.27135, -1.... | [-0.008744, -0.007463, -0.006006, -0.004360999... 
| [-0.012428999999999999, -0.01247, -0.012466, -... | [0.002436, 0.001673, 0.000812, -0.000139, -0.0... | [31.828133, 31.580637, 31.387247, 31.242411, 3... | [-0.0011359999999999999, 0.00102, 0.001377, 0.... | [90.242859, 90.552948, 91.0634, 91.718605, 92.... | [9.194756, 8.696275, 8.075597, 7.379229, 6.644... | [4.146616, 4.320601, 4.622187, 5.041539, 5.558... | [-21.498556, -20.357746, -19.085051, -17.74668... | [9.194759, 8.696277, 8.075601, 7.379233, 6.644... | [4.146615, 4.3206, 4.622186, 5.041538, 5.55823... | [-21.498556, -20.357746, -19.085051, -17.74668... | [18.847454, 18.874477, 18.864149, 18.821054, 1... | [-5.771685, -5.643616, -5.501174, -5.346303, -... | [-7.284528, -7.00043, -6.623891, -6.122418, -5... | [7.355994, 7.557222, 8.092501, 8.881246, 9.836... | [2.684278, 2.574164, 2.540403, 2.580063, 2.684... | [-10.199863, -10.37021, -10.588555, -10.881887... | [-1.597058, -1.574028, -1.5760939999999999, -1... | [-1.239754, -1.412793, -1.553985, -1.665565, -... | [0.427455, 0.937423, 1.446885, 1.9577300000000... | [-3.041614, -3.081234, -3.148071, -3.257658, -... | [-10.598993, -10.551416, -10.489378, -10.40383... | [10.954481, 10.959135, 10.947763, 10.889692, 1... | [8.854305, 8.837968, 8.840726, 8.842192, 8.807... | [-1.246204, -1.063445, -0.88775, -0.716517, -0... | [-5.760704, -5.845308, -5.895671, -5.918782, -... | [7.472655, 7.480461, 7.478623, 7.462115, 7.425... | [-1.736641, -1.716788, -1.6773069999999999, -1... | [-5.447077, -4.997944, -4.516113, -4.006367, -... |
| 2 | 19 | sub-002 | male | 2024-05-29 12:27:14 | cycle_11 | [-4.634666, -5.02648, -5.340784, -5.544887, -5... | [7.019156, 7.207421, 7.360003, 7.428865, 7.382... | [0.8590829999999999, 0.9442879999999999, 1.291... | [0.020581, 0.018594, 0.016465, 0.014183, 0.011... | [-0.041350000000000005, -0.04147000000000001, ... | [-0.002153, -0.003436, -0.004659, -0.005737, -... | [21.892666, 21.667431, 21.510859, 21.407866, 2... | [0.006052, 0.0068579999999999995, 0.0050479999... | [103.10041, 103.414948, 103.769325, 104.036888... | [13.801257, 13.364927, 12.614091, 11.681816, 1... | [-7.515644, -7.981946, -8.420057, -8.740601999... | [19.812559, 19.102261, 18.223965, 17.228157, 1... | [13.801261, 13.36493, 12.614095, 11.681819, 10... | [-7.515642, -7.981944, -8.420057, -8.740601999... | [-19.812559, -19.102261, -18.223965, -17.22815... | [21.490784, 21.206192, 20.810915, 20.239679, 1... | [-1.003782, -1.046244, -1.167168, -1.336805, -... | [1.397087, 2.274668, 2.839696, 3.0362, 2.89127... | [4.38941, 4.258365, 4.485112, 4.878491, 5.2717... | [1.547475, 1.425399, 1.3594089999999999, 1.309... | [-18.630737, -19.466963, -20.113201, -20.44350... | [1.38354, 1.262527, 1.063718, 0.784248, 0.4453... | [0.217076, -0.012333, -0.265026, -0.52868, -0.... | [-3.368124, -2.697494, -1.98144, -1.229187, -0... | [1.134091, 0.990254, 0.783542, 0.540679, 0.294... | [-10.493185, -10.305758, -10.112445, -9.931076... | [17.463486, 17.312468, 17.166122, 17.092224, 1... | [3.597836, 3.621159, 3.705733, 3.861809, 4.075... | [-1.154968, -0.9766429999999999, -0.759751, -0... | [-2.844895, -2.9492849999999997, -3.022531, -3... | [5.022056, 4.928493, 4.817266, 4.69589, 4.5721... | [-0.685429, -0.7331249999999999, -0.7694989999... | [-6.30276, -5.719454, -5.056945, -4.33003, -3.... | [-6.689217, -7.057434, -7.268858, -7.303356, -... | [4.002138, 4.135075, 4.315419, 4.513762, 4.705... | [-1.061941, -0.874251, -0.6129100000000001, -0... | [-0.010591, -0.008883, -0.006986999999999999, ... 
| [-0.013071, -0.014, -0.014787, -0.015403, -0.0... | [0.002051, 0.001292, 0.0005059999999999999, -0... | [33.081654, 32.856186, 32.723984, 32.691216, 3... | [0.005253, 0.00417, 0.002565, 0.001298, 0.0009... | [92.021736, 92.243813, 92.391449, 92.442314, 9... | [8.570474, 7.685758, 6.655786, 5.567538, 4.500... | [4.459383, 4.893425, 5.360785, 5.814991, 6.221... | [-14.959339, -13.506869, -12.001929, -10.49032... | [8.570477, 7.685762, 6.65579, 5.56754199999999... | [4.459382, 4.893425, 5.360784, 5.81499, 6.2219... | [-14.959339, -13.506869, -12.001929, -10.49032... | [16.00716, 16.189598, 16.405577, 16.658844, 16... | [-2.608135, -2.575316, -2.595477, -2.666513, -... | [-5.038992, -5.088731, -5.152913, -5.168307, -... | [3.29341, 3.71739, 4.438928, 5.395056, 6.51829... | [0.818074, 0.758255, 0.774924, 0.854435, 0.989... | [-9.693586, -9.014916, -8.393165, -7.926613, -... | [-1.286025, -1.090904, -0.854332, -0.579944999... | [-0.204015, -0.38847899999999996, -0.562819, -... | [1.125047, 1.8016709999999998, 2.525921, 3.273... | [-4.353323, -4.315862, -4.385724, -4.564363, -... | [-10.74014, -10.643498, -10.525556, -10.392888... | [13.31672, 13.247898, 13.045144, 12.724239, 12... | [8.859, 8.63991, 8.336942, 7.95187, 7.500204, ... | [0.103183, 0.21866000000000002, 0.307375, 0.37... | [-5.57998, -5.7986889999999995, -6.033959, -6.... | [7.547276, 7.527881, 7.468235, 7.366483, 7.225... | [0.632976, 0.589656, 0.5276609999999999, 0.448... | [-4.3928519999999995, -3.91817, -3.416698, -2.... |
Initial Data Analysis¶
Now we take a closer look at the data.
As we can see, we have 3,656 gait cycles and 71 features. None of the features contains missing values.
We can also observe that the unique counts per column are not equal, which means that for some gait cycles a measurement could not be tracked/calculated or contained missing/NaN values and was cleaned beforehand. Subject 69 has the most gait cycles in the data. Female is the most represented gender. The most gait cycles were captured on 2023-08-08. Cycle 1 is the most frequent cycle label.
From the age column, we observe that the average age of the subjects is 54.34 years, with a standard deviation of 24.61. Ages range from a minimum of 0 to a maximum of 93 years. The majority of subjects fall within the age range of 30 to 74 years. The minimum age of 0 suggests the presence of outliers, which we will address during data preprocessing by removing these entries.
print(f'Shape of the data: {data_raw.shape}')
display(data_raw.info())
display(data_raw.describe(include='object'))
display(data_raw.describe(include='number'))
Shape of the data: (3656, 71) <class 'pandas.core.frame.DataFrame'> RangeIndex: 3656 entries, 0 to 3655 Data columns (total 71 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 age 3656 non-null int64 1 subject 3656 non-null object 2 sex 3656 non-null object 3 created 3656 non-null object 4 cycle 3656 non-null object 5 Left Ankle Angles_X 3656 non-null object 6 Left Ankle Angles_Y 3656 non-null object 7 Left Ankle Angles_Z 3656 non-null object 8 Left CenterOfMass_corr_X 3656 non-null object 9 Left CenterOfMass_corr_Y 3656 non-null object 10 Left CenterOfMass_corr_Z 3656 non-null object 11 Left Elbow Angles_X 3656 non-null object 12 Left Elbow Angles_Y 3656 non-null object 13 Left Elbow Angles_Z 3656 non-null object 14 Left Foot Pitch Angles_X 3656 non-null object 15 Left Foot Pitch Angles_Y 3656 non-null object 16 Left Foot Pitch Angles_Z 3656 non-null object 17 Left Foot Progression_X 3656 non-null object 18 Left Foot Progression_Y 3656 non-null object 19 Left Foot Progression_Z 3656 non-null object 20 Left Hip Angles_X 3656 non-null object 21 Left Hip Angles_Y 3656 non-null object 22 Left Hip Angles_Z 3656 non-null object 23 Left Knee Angles_X 3656 non-null object 24 Left Knee Angles_Y 3656 non-null object 25 Left Knee Angles_Z 3656 non-null object 26 Left Pelvic Angles_X 3656 non-null object 27 Left Pelvic Angles_Y 3656 non-null object 28 Left Pelvic Angles_Z 3656 non-null object 29 Left Shoulder Angles_X 3656 non-null object 30 Left Shoulder Angles_Y 3656 non-null object 31 Left Shoulder Angles_Z 3656 non-null object 32 Left Thorax Angles_X 3656 non-null object 33 Left Thorax Angles_Y 3656 non-null object 34 Left Thorax Angles_Z 3656 non-null object 35 Left Thorax_Lab Angles_X 3656 non-null object 36 Left Thorax_Lab Angles_Y 3656 non-null object 37 Left Thorax_Lab Angles_Z 3656 non-null object 38 Right Ankle Angles_X 3656 non-null object 39 Right Ankle Angles_Y 3656 non-null object 40 Right Ankle Angles_Z 3656 non-null object 41 Right 
CenterOfMass_corr_X 3656 non-null object 42 Right CenterOfMass_corr_Y 3656 non-null object 43 Right CenterOfMass_corr_Z 3656 non-null object 44 Right Elbow Angles_X 3656 non-null object 45 Right Elbow Angles_Y 3656 non-null object 46 Right Elbow Angles_Z 3656 non-null object 47 Right Foot Pitch Angles_X 3656 non-null object 48 Right Foot Pitch Angles_Y 3656 non-null object 49 Right Foot Pitch Angles_Z 3656 non-null object 50 Right Foot Progression_X 3656 non-null object 51 Right Foot Progression_Y 3656 non-null object 52 Right Foot Progression_Z 3656 non-null object 53 Right Hip Angles_X 3656 non-null object 54 Right Hip Angles_Y 3656 non-null object 55 Right Hip Angles_Z 3656 non-null object 56 Right Knee Angles_X 3656 non-null object 57 Right Knee Angles_Y 3656 non-null object 58 Right Knee Angles_Z 3656 non-null object 59 Right Pelvic Angles_X 3656 non-null object 60 Right Pelvic Angles_Y 3656 non-null object 61 Right Pelvic Angles_Z 3656 non-null object 62 Right Shoulder Angles_X 3656 non-null object 63 Right Shoulder Angles_Y 3656 non-null object 64 Right Shoulder Angles_Z 3656 non-null object 65 Right Thorax Angles_X 3656 non-null object 66 Right Thorax Angles_Y 3656 non-null object 67 Right Thorax Angles_Z 3656 non-null object 68 Right Thorax_Lab Angles_X 3656 non-null object 69 Right Thorax_Lab Angles_Y 3656 non-null object 70 Right Thorax_Lab Angles_Z 3656 non-null object dtypes: int64(1), object(70) memory usage: 2.0+ MB
None
| subject | sex | created | cycle | Left Ankle Angles_X | Left Ankle Angles_Y | Left Ankle Angles_Z | Left CenterOfMass_corr_X | Left CenterOfMass_corr_Y | Left CenterOfMass_corr_Z | Left Elbow Angles_X | Left Elbow Angles_Y | Left Elbow Angles_Z | Left Foot Pitch Angles_X | Left Foot Pitch Angles_Y | Left Foot Pitch Angles_Z | Left Foot Progression_X | Left Foot Progression_Y | Left Foot Progression_Z | Left Hip Angles_X | Left Hip Angles_Y | Left Hip Angles_Z | Left Knee Angles_X | Left Knee Angles_Y | Left Knee Angles_Z | Left Pelvic Angles_X | Left Pelvic Angles_Y | Left Pelvic Angles_Z | Left Shoulder Angles_X | Left Shoulder Angles_Y | Left Shoulder Angles_Z | Left Thorax Angles_X | Left Thorax Angles_Y | Left Thorax Angles_Z | Left Thorax_Lab Angles_X | Left Thorax_Lab Angles_Y | Left Thorax_Lab Angles_Z | Right Ankle Angles_X | Right Ankle Angles_Y | Right Ankle Angles_Z | Right CenterOfMass_corr_X | Right CenterOfMass_corr_Y | Right CenterOfMass_corr_Z | Right Elbow Angles_X | Right Elbow Angles_Y | Right Elbow Angles_Z | Right Foot Pitch Angles_X | Right Foot Pitch Angles_Y | Right Foot Pitch Angles_Z | Right Foot Progression_X | Right Foot Progression_Y | Right Foot Progression_Z | Right Hip Angles_X | Right Hip Angles_Y | Right Hip Angles_Z | Right Knee Angles_X | Right Knee Angles_Y | Right Knee Angles_Z | Right Pelvic Angles_X | Right Pelvic Angles_Y | Right Pelvic Angles_Z | Right Shoulder Angles_X | Right Shoulder Angles_Y | Right Shoulder Angles_Z | Right Thorax Angles_X | Right Thorax Angles_Y | Right Thorax Angles_Z | Right Thorax_Lab Angles_X | Right Thorax_Lab Angles_Y | Right Thorax_Lab Angles_Z | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 | 3656 |
| unique | 89 | 2 | 91 | 137 | 2921 | 2921 | 2921 | 2004 | 2004 | 2004 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2921 | 2957 | 2957 | 2957 | 2083 | 2083 | 2083 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 | 2957 |
| top | sub-069 | female | 2023-08-08 12:44:55 | cycle_1 | [6.893189, 6.573399, 6.306869, 6.095495, 5.940... | [2.444143, 2.506672, 2.569635, 2.63392, 2.6998... | [11.521898, 11.749121, 12.024985, 12.331873, 1... | [0.011529, 0.011380999999999999, 0.01116799999... | [-0.010992, -0.011198, -0.011365, -0.011496, -... | [0.006470000000000001, 0.005605, 0.004847, 0.0... | [67.722382, 67.841942, 67.953278, 68.052673, 6... | [0.01161, 0.008588, 0.005424, 0.00264499999999... | [67.336884, 67.013924, 66.85788, 66.890991, 67... | [4.156376, 4.098315, 3.997932, 3.86376, 3.7056... | [-3.6453759999999997, -3.7622679999999997, -3.... | [-1.318557, -1.5197289999999999, -1.7227109999... | [4.156377, 4.098315, 3.9979329999999997, 3.863... | [-3.6453759999999997, -3.7622679999999997, -3.... | [1.318557, 1.5197289999999999, 1.7227109999999... | [31.564228, 31.398096, 31.248755, 31.126377, 3... | [7.359384, 7.375715, 7.397888, 7.422212, 7.445... | [1.3239239999999999, 1.192659, 1.048491, 0.902... | [22.173082, 21.862728, 21.636713, 21.493454, 2... | [2.108773, 2.00006, 1.9031889999999998, 1.8254... | [-7.5444320000000005, -7.579482, -7.648215, -7... | [12.448437, 12.351758, 12.283578, 12.248773, 1... | [6.319764, 6.260339, 6.21349, 6.176859, 6.1474... | [-6.090747, -5.965488, -5.839082, -5.710993, -... | [18.469509, 18.263018, 18.057535, 17.853489, 1... | [-36.644535, -36.599392, -36.554646, -36.50981... | [-6.003837, -6.064888, -6.121649, -6.179747, -... | [14.478528, 14.447725, 14.380266, 14.271049, 1... | [-1.5006400000000002, -1.540162, -1.590595, -1... | [1.819955, 1.7617509999999998, 1.7148569999999... | [26.989492, 26.855251, 26.713287, 26.563417, 2... | [4.1572379999999995, 4.090972, 4.022262, 3.951... | [-5.144996, -5.084937, -5.018147, -4.943727, -... | [14.523161, 14.508266, 14.497878, 14.492327, 1... | [-11.160184, -11.269531, -11.366331, -11.44471... | [13.599168, 13.521851999999999, 13.468516, 13.... | [0.00047000000000000004, 0.0005279999999999999... 
| [0.001856, 0.0019459999999999998, 0.0019749999... | [0.006651, 0.005778, 0.004824999999999999, 0.0... | [61.774261, 61.173229, 60.599789, 60.090508, 5... | [0.019167999999999998, 0.016928, 0.015477, 0.0... | [97.601425, 97.501183, 97.333618, 97.064056, 9... | [0.220742, 0.201698, 0.188475, 0.180799, 0.178... | [-6.223454, -6.310985, -6.387176, -6.447673, -... | [9.090084, 9.080581, 9.085161, 9.10679, 9.1477... | [0.220744, 0.201699, 0.188476, 0.180801, 0.178... | [-6.223454, -6.310985, -6.387176, -6.447673, -... | [9.090084, 9.080581, 9.085161, 9.10679, 9.1477... | [20.799522, 20.793251, 20.7437, 20.65752, 20.5... | [-3.10834, -3.100393, -3.090199, -3.077057, -3... | [-7.478582, -7.461316, -7.4582429999999995, -7... | [20.235523, 20.212086, 20.171434, 20.115973, 2... | [3.360786, 3.335346, 3.310877, 3.287839, 3.266... | [-9.518766, -9.45668, -9.397385, -9.348419, -9... | [12.843493, 12.866432, 12.861176, 12.832086, 1... | [-5.967627, -5.997739, -6.024588, -6.047569, -... | [13.495292, 13.477127, 13.465643, 13.461199, 1... | [21.69734, 22.23702, 22.759691, 23.227921, 23.... | [-31.540339, -31.487875, -31.409733, -31.32809... | [-19.184639, -19.156521, -19.19298, -19.269421... | [12.455172, 12.51285, 12.602796, 12.717604, 12... | [-0.8873709999999999, -0.8704419999999999, -0.... | [-7.802131, -7.800966, -7.7954740000000005, -7... | [25.967037, 26.047792, 26.133192, 26.220331, 2... | [-3.387635, -3.392554, -3.40273, -3.418771, -3... | [6.05367, 6.047386, 6.048463, 6.056428, 6.0703... |
| freq | 134 | 1930 | 134 | 112 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |
| age | |
|---|---|
| count | 3656.000000 |
| mean | 54.342177 |
| std | 24.618695 |
| min | 0.000000 |
| 25% | 30.000000 |
| 50% | 62.000000 |
| 75% | 74.000000 |
| max | 93.000000 |
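The observations above can also be verified programmatically. A hedged sketch on a small hypothetical DataFrame (the real notebook would run the same calls on `data_raw`), including the age-0 outlier removal announced for the preprocessing step:

```python
import pandas as pd

# Hypothetical mini-dataset mirroring the real columns
demo = pd.DataFrame({
    "subject": ["sub-069", "sub-069", "sub-002"],
    "sex": ["female", "female", "male"],
    "age": [62, 62, 0],
})

print(demo["subject"].value_counts().idxmax())   # most frequent subject
print(demo["sex"].value_counts().idxmax())       # most frequent sex

cleaned = demo[demo["age"] > 0]                  # drop age-0 outliers
print(len(cleaned))                              # rows remaining
```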
Preprocessing¶
In the following chapter we preprocess the data.
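Before the model can consume the data, the label must be numeric and the list-valued angle columns must become a numeric matrix. A minimal sketch of these two steps on toy data (assumed steps, not the notebook's exact pipeline; the encoding male=0/female=1 is an illustrative choice):

```python
import numpy as np
import pandas as pd

# Toy frame with one list-valued angle column, mirroring the real layout
demo = pd.DataFrame({
    "sex": ["male", "female"],
    "Left Knee Angles_X": [list(np.zeros(101)), list(np.ones(101))],
})

y = (demo["sex"] == "female").astype(int).to_numpy()   # male=0, female=1
X = np.stack(demo["Left Knee Angles_X"].to_numpy())    # shape (n_cycles, 101)
print(y, X.shape)
```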
display(data_raw.head(2))
| age | subject | sex | created | cycle | Left Ankle Angles_X | Left Ankle Angles_Y | Left Ankle Angles_Z | Left CenterOfMass_corr_X | Left CenterOfMass_corr_Y | Left CenterOfMass_corr_Z | Left Elbow Angles_X | Left Elbow Angles_Y | Left Elbow Angles_Z | Left Foot Pitch Angles_X | Left Foot Pitch Angles_Y | Left Foot Pitch Angles_Z | Left Foot Progression_X | Left Foot Progression_Y | Left Foot Progression_Z | Left Hip Angles_X | Left Hip Angles_Y | Left Hip Angles_Z | Left Knee Angles_X | Left Knee Angles_Y | Left Knee Angles_Z | Left Pelvic Angles_X | Left Pelvic Angles_Y | Left Pelvic Angles_Z | Left Shoulder Angles_X | Left Shoulder Angles_Y | Left Shoulder Angles_Z | Left Thorax Angles_X | Left Thorax Angles_Y | Left Thorax Angles_Z | Left Thorax_Lab Angles_X | Left Thorax_Lab Angles_Y | Left Thorax_Lab Angles_Z | Right Ankle Angles_X | Right Ankle Angles_Y | Right Ankle Angles_Z | Right CenterOfMass_corr_X | Right CenterOfMass_corr_Y | Right CenterOfMass_corr_Z | Right Elbow Angles_X | Right Elbow Angles_Y | Right Elbow Angles_Z | Right Foot Pitch Angles_X | Right Foot Pitch Angles_Y | Right Foot Pitch Angles_Z | Right Foot Progression_X | Right Foot Progression_Y | Right Foot Progression_Z | Right Hip Angles_X | Right Hip Angles_Y | Right Hip Angles_Z | Right Knee Angles_X | Right Knee Angles_Y | Right Knee Angles_Z | Right Pelvic Angles_X | Right Pelvic Angles_Y | Right Pelvic Angles_Z | Right Shoulder Angles_X | Right Shoulder Angles_Y | Right Shoulder Angles_Z | Right Thorax Angles_X | Right Thorax Angles_Y | Right Thorax Angles_Z | Right Thorax_Lab Angles_X | Right Thorax_Lab Angles_Y | Right Thorax_Lab Angles_Z | |
(output: the first two rows of `data_raw` — `age`, `subject`, `sex`, `created`, `cycle`, followed by 66 angle columns, each holding a list of 101 values per gait cycle)
To uniquely identify each subject at each recording time, we combine the features 'subject' and 'created'.
data_raw['session_name'] = data_raw['subject'] + '_' + data_raw['created']
Then we can drop the columns 'subject' and 'created'.
data_raw.drop(['subject', 'created'], axis=1, inplace=True)
Next we check that the lists in the angle columns contain no NaN values and all have the same length of 101.
cols = data_raw.columns.tolist()
col_to_check = cols[3:-1] # all angle columns (exclude age, sex, cycle and session_name)
# Check whether all lists have the same length
lengths = data_raw[col_to_check].map(len)
print(lengths.value_counts())
# Convert all lists to numpy arrays
data_raw[col_to_check] = data_raw[col_to_check].map(lambda x: np.array(x) if isinstance(x, list) else x)
# Check whether the lists contain NaN values
nan_counts = data_raw[col_to_check].map(lambda x: np.isnan(x).any() if isinstance(x, np.ndarray) else False)
print(nan_counts.value_counts())
(output: every checked list column has length 101 for all 3656 rows, and the NaN check reports `False` for every column, i.e. the lists contain no missing values)
Next we reorder the columns, moving `session_name` to the first position, and then examine the age of each subject.
# bring the last column to the first position
cols = data_raw.columns.tolist()
cols = cols[-1:] + cols[:-1]
data_raw = data_raw[cols]
display(data_raw.head(3))
(output: the first three rows of `data_raw` after reordering — `session_name`, `age`, `sex`, `cycle`, followed by the 66 angle columns, each holding a list of 101 values per gait cycle)
df_age = data_raw.groupby(by=['sex', 'session_name']).agg({'age': 'first'}).reset_index().sort_values(by=['sex', 'age'])
display(df_age.describe())
| age | |
|---|---|
| count | 91.000000 |
| mean | 52.450549 |
| std | 23.708069 |
| min | 0.000000 |
| 25% | 29.000000 |
| 50% | 58.000000 |
| 75% | 72.500000 |
| max | 93.000000 |
We observe that we have 91 subjects. The maximum age is 93 years and the minimum age is 0 years, so we have to remove outliers.
We consider subjects younger than 5 years old as outliers.
print(f'Shape before filter: {data_raw.shape}')
data_raw = data_raw.loc[data_raw['age'] >= 5]
print(f'Shape after filter: {data_raw.shape}')
Shape before filter: (3656, 70) Shape after filter: (3471, 70)
Exploratory Data Analysis¶
Age Distribution for each Gender¶
display(data_raw.age.describe())
sns.violinplot(df_age, x='sex', y='age', hue='sex')
plt.title('Age Distribution for each Gender')
plt.show()
count 3471.000000 mean 57.238548 std 21.738467 min 6.000000 25% 39.000000 50% 63.000000 75% 74.000000 max 93.000000 Name: age, dtype: float64
The dataset consists of 3471 gait cycles, with an average subject age of approximately 57.2 years.
The age standard deviation is 21.7, reflecting a wide range of ages within the dataset.
Ages range from a minimum of 6 years to a maximum of 93 years.
The 25th percentile age is 39 years, the median age (50th percentile) is 63 years, and the 75th percentile is 74 years, indicating that most ages are concentrated between 39 and 74 years.
Both distributions display similar shapes and spreads, with no substantial visual differences between the male and female groups; the female distribution merely extends slightly further toward older ages.
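The visual comparison can be complemented by per-sex summary statistics. A minimal sketch — the frame and ages below are hypothetical stand-ins for the real `df_age` data:

```python
import pandas as pd

# Hypothetical stand-in for the real df_age frame (one row per subject)
df_age = pd.DataFrame({
    'sex': ['female'] * 5 + ['male'] * 4,
    'age': [25, 41, 63, 77, 90, 30, 55, 68, 81],
})

# Per-sex summary statistics mirror what the violin plot shows visually
stats = df_age.groupby('sex')['age'].describe()
print(stats[['count', 'mean', 'std', 'min', 'max']])
```

Running the same `groupby(...).describe()` on the real `df_age` quantifies whether the spreads differ beyond what the plot suggests.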
Gait Cycles Females and Males¶
df_plot = data_raw.groupby(by=['sex']).agg({'session_name': 'count'}).reset_index()
fig = plt.figure(figsize=(10, 5))
sns.barplot(df_plot, x='sex', y='session_name', hue='sex')
# set title
plt.title('Number of Gait Cycles for Females and Males')
plt.ylabel('Number of gait cycles')
plt.xlabel('Sex')
plt.show()
We can observe that we have slightly more gait cycles from female subjects than from male subjects, which means the target is mildly imbalanced.
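If needed, such an imbalance can later be compensated with class weights during training. A minimal sketch with hypothetical counts (the real labels live in `data_raw['sex']`):

```python
import pandas as pd

# Hypothetical label series standing in for data_raw['sex']
sex = pd.Series(['female'] * 60 + ['male'] * 40)

counts = sex.value_counts()
n_total, n_classes = len(sex), counts.size

# "Balanced" weighting: n_total / (n_classes * n_class_samples),
# so the minority class receives the larger weight
class_weights = {label: n_total / (n_classes * n) for label, n in counts.items()}
print(class_weights)
```

These weights can be passed to most loss functions or estimators so that errors on the minority class count proportionally more.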
Angles in Gait Cycles¶
Next, we visually examine whether there are observable differences in walking patterns between males and females based on joint angles.
def get_mean_std(data: pd.DataFrame, joint_name: str, gender: str):
    """
    Calculates the mean and standard deviation of gait cycle values for a specified joint and gender.
    """
    # stack the per-cycle lists into a 2D array (cycles x 101 samples)
    matrix = np.stack(data.loc[data['sex'] == gender, joint_name].to_list())
    std = np.std(matrix, axis=0)
    mean = np.mean(matrix, axis=0)
    x = np.arange(len(mean))
    return x, mean, std
# Get all column names for the angles
colnames = data_raw.columns.tolist()[4:]
# groups of 3 columns for each joint (X, Y, Z)
joint_axes_groups = [
colnames[i:i + 3] for i in range(0, len(colnames), 3)
]
n_rows = len(joint_axes_groups) # number of joints
n_cols = 3 # X, Y, Z
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
# Plotting
for i, joint_axes in enumerate(joint_axes_groups):
for j, col in enumerate(joint_axes):
ax = axes[i, j]
# Females
color = 'mediumblue'
x, mean, std = get_mean_std(data_raw, col, 'female')
ax.plot(x, mean, color=color)
ax.fill_between(x, mean - std, mean + std, alpha=0.3, color=color)
# Males
color = 'tomato'
x, mean, std = get_mean_std(data_raw, col, 'male')
ax.plot(x, mean, color=color)
ax.fill_between(x, mean - std, mean + std, alpha=0.3, color=color)
# Title and labels
ax.set_title(col)
ax.set_xlabel('Gait Cycle')
ax.set_ylabel('Angle')
plt.tight_layout()
plt.show()
Observation:
The plots reveal distinct differences in walking patterns between males and females, highlighted by variations in both average angles and standard deviations across nearly all joint features.
Males (in red) typically exhibit higher amplitude movements and more consistent patterns, particularly in the knee and ankle joints, suggesting greater range of motion and stability. In contrast, females (in blue) display more variability, especially in the knee and hip angles, along with slight asymmetries and timing differences. These patterns likely reflect underlying biomechanical differences, with males demonstrating a more stable, rhythmic gait and females showing subtle fluctuations and increased variability.
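The claim about differing variability can be checked numerically by comparing the average per-timepoint standard deviation between the two groups. A minimal sketch on synthetic curves (not the actual joint data):

```python
import numpy as np

rng = np.random.default_rng(0)
t = np.linspace(0, 2 * np.pi, 101)  # 101 samples per gait cycle

# Synthetic gait curves: group B has more cycle-to-cycle noise than group A
curves_a = np.sin(t) + 0.1 * rng.normal(size=(50, 101))
curves_b = np.sin(t) + 0.3 * rng.normal(size=(50, 101))

# Average per-timepoint standard deviation as a scalar variability score
var_a = np.std(curves_a, axis=0).mean()
var_b = np.std(curves_b, axis=0).mean()
print(f'group A: {var_a:.3f}, group B: {var_b:.3f}')
```

Applied per joint column (using the matrices from `get_mean_std`), this turns the visual impression into a comparable number per sex.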
Correlations¶
def get_matrix(data: pd.DataFrame, joint_name: str):
    """
    Stacks the gait cycle lists of a specified joint column into a 2D array.
    """
    return np.stack(data[joint_name].to_list())
# Get all column names for the angles
colnames = data_raw.columns.tolist()[4:]
arrays = []
for col in colnames:
matrix = get_matrix(data_raw, col)
arrays.append(matrix)
# Stack arrays along a new dimension, resulting in a 3D array
stacked = np.stack(arrays, axis=0)
print(f'Shape of stacked array: {stacked.shape}')
print(stacked.reshape(len(stacked), -1).shape)
# Calculate element-wise correlation across the arrays
corr = np.corrcoef(stacked.reshape(len(stacked), -1))
Shape of stacked array: (66, 3471, 101) (66, 350571)
mask = np.triu(np.ones_like(corr, dtype=bool)) # Source: https://stackoverflow.com/questions/2318529/plotting-only-upper-lower-triangle-of-a-heatmap
# Plot heatmap
plt.figure(figsize=(30, 30))
sns.heatmap(corr, annot=True, cmap='coolwarm', linewidth=0.5, fmt='.2f', annot_kws={"size": 8},
mask=mask, vmin=-1, vmax=1,
xticklabels=colnames, yticklabels=colnames
)
plt.title('Correlation Matrix of the Gait Cycles', fontsize=14, fontweight='bold')
plt.xticks(fontsize=12, rotation=90)
plt.yticks(fontsize=12, rotation=0)
plt.show()
Observations:
Strong Positive Correlations:
- Certain joint angles or measurements appear to be strongly correlated with others, indicated by dark red cells. For example, the left and right thorax angles along various axes seem to have high positive correlations, which could reflect symmetrical movement patterns.
- The hip angles and knee angles also show moderate to high positive correlations, suggesting that movement in these joints is coordinated during gait.
Symmetrical Patterns:
- The left and right counterparts for many body parts, like elbows, ankles, and shoulders, often show moderate to high correlations. This symmetry indicates that both sides of the body move in a coordinated manner, which is expected in a balanced gait cycle.
Negative Correlations:
- There are areas of strong negative correlation (dark blue cells), especially involving certain axes of shoulder, elbow, and pelvic movements. These could signify compensatory movements where one part of the body counterbalances the motion of another.
- For example, specific shoulder angles may counterbalance thorax or pelvic motions to maintain stability during gait.
Less Correlated Features:
- Some variables, especially those related to more distant body parts (e.g., left foot angles and right shoulder angles), show low or near-zero correlations, suggesting these movements are more independent from each other.
Potential Redundancies:
- High correlations between similar metrics (e.g., progression and center of mass coordinates) imply that some measurements could be redundant. This could be useful if you're considering dimensionality reduction techniques like PCA to reduce the number of features while retaining most of the information.
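The redundancy suggested by the strong correlations can be verified with PCA. A minimal sketch on synthetic correlated features (not the actual gait data), using plain NumPy SVD rather than a library PCA class:

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic data: 200 samples, 6 features that are noisy copies of
# only 2 independent signals -> highly correlated columns
base = rng.normal(size=(200, 2))
X = np.hstack([base + 0.05 * rng.normal(size=(200, 2)) for _ in range(3)])

# PCA via SVD of the centered data
Xc = X - X.mean(axis=0)
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
explained = S**2 / np.sum(S**2)

# number of components needed to explain 99% of the variance
k = int(np.searchsorted(np.cumsum(explained), 0.99) + 1)
print(f'{X.shape[1]} features -> {k} principal components')
```

On the real feature matrix, the same procedure would reveal how many effective dimensions the 66 correlated angle curves actually span.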
Baselines¶
Preprocessing for the Models¶
For the upcoming steps we process the data further. This is necessary to make the data compatible with the models.
Next we define the data and the targets for the following tasks. We also map 1 for female and 0 for male, as those are the target values for the prediction.
target = data_raw['sex'] # define the target
target = target.map({'female': 1, 'male': 0}) # assign 1 for female and 0 for male
data = data_raw.drop(columns=['sex', 'session_name', 'cycle']) # drop the columns sex, session_name and cycle from the data
data['age'] = data['age'].astype('int') # convert the age to integer
print('Data Head')
display(data.head(4))
print('Target Head')
display(target.head(4))
print(f'Shape of the data: {data.shape}')
print(f'Shape of the target: {target.shape}')
Data Head
*(Output truncated: a 4-row preview of the `data` DataFrame with 67 columns — `age` plus 66 joint-angle columns from `Left Ankle Angles_X` to `Right Thorax_Lab Angles_Z`, where each cell holds a 101-sample time series.)*
Target Head
0    0
1    0
2    0
3    0
Name: sex, dtype: int64
Shape of the data: (3471, 67)
Shape of the target: (3471,)
Next we define gait_data and age_data because we have to handle them separately in the upcoming steps.
gait_data = data.drop(columns=['age']) # drop the age column
age_data = data['age']
display(gait_data.head(2))
*(Output truncated: a 2-row preview of the `gait_data` DataFrame with the 66 joint-angle columns; each cell holds a 101-sample time series.)*
Next we reshape gait_data into the shape (gait cycles, 101 time points, features).
expanded_features = []
# Convert each column (a Series of 101-sample lists) to a 2D numpy array
for col in gait_data.columns:
    expanded_col = np.array(gait_data[col].to_list())  # shape (gait cycles, 101)
    expanded_features.append(expanded_col)
# Stack the arrays along a new last axis to create a 3D numpy array
gait_data = np.stack(expanded_features, axis=-1)
# slicing: [gait cycles, 101 time points, features]
print(gait_data[:, :, :].shape)
(3471, 101, 66)
Next we concatenate the age data back onto the gait data so that everything can be fitted in the scaler together.
age_data = age_data.to_numpy().reshape(-1, 1)
# Repeat the age value along the time axis
age_data_expanded = np.repeat(age_data[:, np.newaxis, :], gait_data.shape[1], axis=1) # Shape (3471, 101, 1)
print(f'Shape of the age data: {age_data_expanded.shape}')
# Concatenate gait_data and age_data
combined_data = np.concatenate([gait_data, age_data_expanded], axis=-1) # Shape (3471, 101, 67)
print(f'Shape of the combined data: {combined_data.shape}')
Shape of the age data: (3471, 101, 1)
Shape of the combined data: (3471, 101, 67)
Next we split the data into train, validation and test sets. During the training of our deep learning model we use the validation data to track the model's progress. At the end we compare the results on the unseen test data.
# flatten the combined_data for the train, test split
data = combined_data.reshape(len(combined_data), -1)
target = target.to_numpy().reshape(-1, 1)
print(f'Shape of the data: {data.shape}')
print(f'Shape of the target: {target.shape}')
# Split the data into train, test and validation set
data_train, data_temp, target_train, target_temp = train_test_split(data, target, test_size=0.4, random_state=1, stratify=target)
data_val, data_test, target_val, target_test = train_test_split(data_temp, target_temp, test_size=0.5, random_state=1, stratify=target_temp)
# scale the data
scaler = MinMaxScaler()
X_train = scaler.fit_transform(data_train)
X_val = scaler.transform(data_val)
X_test = scaler.transform(data_test)
print(f'Shape of the train data: {data_train.shape}')
print(f'Shape of the validation data: {data_val.shape}')
print(f'Shape of the test data: {data_test.shape}')
print(f'Shape of the train target: {target_train.shape}')
print(f'Shape of the validation target: {target_val.shape}')
print(f'Shape of the test target: {target_test.shape}')
Shape of the data: (3471, 6767)
Shape of the target: (3471, 1)
Shape of the train data: (2082, 6767)
Shape of the validation data: (694, 6767)
Shape of the test data: (695, 6767)
Shape of the train target: (2082, 1)
Shape of the validation target: (694, 1)
Shape of the test target: (695, 1)
Evaluation Metric¶
We choose accuracy and the Area Under the Curve (AUC) score as evaluation metrics.
The AUC is a common evaluation metric for binary classification, tied to the Receiver Operating Characteristic (ROC) curve. The ROC curve plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at different threshold levels. The AUC score quantifies the overall ability of the model to distinguish between the positive and negative classes. An AUC of 0.5 suggests no discrimination capacity (equivalent to random guessing), while an AUC of 1.0 indicates perfect separation of the two classes.
We then compare the AUC score to the accuracy score, which is the fraction of correct predictions among all predictions.
For binary classification it is calculated as:
$Accuracy = \frac{TP + TN}{TP + TN + FP + FN}$
With:
- True Positive (TP)
- True Negative (TN)
- False Positive (FP)
- False Negative (FN)
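As a quick illustration of both metrics, here is a minimal sketch using sklearn on made-up labels following the notebook's encoding (1 = female, 0 = male); the values are illustrative and not project data.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Hypothetical ground truth and hard class predictions
y_true = np.array([1, 0, 1, 1, 0, 0])
y_pred = np.array([1, 0, 1, 0, 0, 1])

acc = accuracy_score(y_true, y_pred)  # (TP + TN) / (TP + TN + FP + FN), 4 of 6 correct
auc = roc_auc_score(y_true, y_pred)   # with hard 0/1 predictions this reduces to (TPR + TNR) / 2

print(f'Accuracy: {acc:.3f}')
print(f'AUC:      {auc:.3f}')
```

Note that passing hard class labels to `roc_auc_score` (as the notebook does) evaluates the ROC curve at a single threshold; passing continuous scores or probabilities would give the full threshold-independent AUC.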
Baseline - Most frequent Class¶
We will begin with a naive baseline model to establish a reference point for comparison. Given that our target classes are slightly imbalanced, we will calculate the most frequent class. The results indicate that the female class is the most prevalent.
From these findings, we observe that by always predicting the female class, we achieve an accuracy of about 52% and an AUC of 50%. The AUC shows that this baseline has no capacity to discriminate between male and female subjects.
# Count unique values and their frequencies
unique, counts = np.unique(target_train, return_counts=True)
max_count_index = np.argmax(counts)
most_frequent = unique[max_count_index]
most_frequent_count = counts[max_count_index]
print(f"Most frequent Class: {most_frequent}, Count: {most_frequent_count}")
# Simple Heuristic
most_frequent_clf = DummyClassifier(strategy='most_frequent')
most_frequent_clf.fit(data_train, target_train)
most_frequent_test_pred = most_frequent_clf.predict(data_test)
most_frequent_val_pred = most_frequent_clf.predict(data_val)
most_frequent_train_pred = most_frequent_clf.predict(data_train)
# This function is used to store the results of the experiments in a table
def add_results(results, algorithm,
                target_train, target_train_pred,
                target_val, target_val_pred,
                target_test, target_test_pred):
    '''
    Create a table with the evaluation results
    of a classification experiment
    '''
    for dataset, actual, predicted in zip(
            ('train', 'val', 'test'),
            (target_train, target_val, target_test),
            (target_train_pred, target_val_pred, target_test_pred)):
        results = pd.concat([results, pd.DataFrame([{
            "algorithm": algorithm,
            "dataset": dataset,
            "Accuracy": round(accuracy_score(actual, predicted)*100, 3),
            "AUC": round(roc_auc_score(actual, predicted)*100, 3),
        }])], ignore_index=True)
    return results
# Calculate accuracy and AUC for the baseline
results = add_results(pd.DataFrame(), "Baseline",
                      target_train=target_train,
                      target_train_pred=most_frequent_train_pred,
                      target_val=target_val,
                      target_val_pred=most_frequent_val_pred,
                      target_test=target_test,
                      target_test_pred=most_frequent_test_pred)
display(results)
Most frequent Class: 1, Count: 1077
| algorithm | dataset | Accuracy | AUC | |
|---|---|---|---|---|
| 0 | Baseline | train | 51.729 | 50.0 |
| 1 | Baseline | val | 51.729 | 50.0 |
| 2 | Baseline | test | 51.799 | 50.0 |
Baseline - SVM¶
Now we run a classical machine learning model against which we can compare the deep learning model. We choose a Support Vector Machine (SVM) as the second baseline.
The AUC scores, ranging from 92.6 to 93.3, indicate that the SVM model has strong discriminatory ability, distinguishing well between the classes across different thresholds.
svm_clf = SVC()
svm_clf.fit(data_train, target_train.reshape(-1))
svm_test_pred = svm_clf.predict(data_test)
svm_val_pred = svm_clf.predict(data_val)
svm_train_pred = svm_clf.predict(data_train)
results = add_results(results, "SVM",
                      target_train=target_train,
                      target_train_pred=svm_train_pred,
                      target_val=target_val,
                      target_val_pred=svm_val_pred,
                      target_test=target_test,
                      target_test_pred=svm_test_pred)
display(results)
display(results)
| algorithm | dataset | Accuracy | AUC | |
|---|---|---|---|---|
| 0 | Baseline | train | 51.729 | 50.000 |
| 1 | Baseline | val | 51.729 | 50.000 |
| 2 | Baseline | test | 51.799 | 50.000 |
| 3 | SVM | train | 92.988 | 92.886 |
| 4 | SVM | val | 92.795 | 92.617 |
| 5 | SVM | test | 93.381 | 93.300 |
Deep Learning Experiments¶
Necessary Steps:¶
Tensors:
torch.tensor is essential for leveraging GPU acceleration, automatic differentiation, and efficient operations in deep learning applications.
We need to convert the numpy arrays into torch.tensor objects with dtype float32. This is necessary for faster computation and because the criterion BCEWithLogitsLoss for binary classification expects floating-point (float32) inputs and targets.
# Create Tensors
data_train_tensor = torch.tensor(data_train, dtype=torch.float32)
data_test_tensor = torch.tensor(data_test, dtype=torch.float32)
data_val_tensor = torch.tensor(data_val, dtype=torch.float32)
target_train_tensor = torch.tensor(target_train, dtype=torch.float32)
target_test_tensor = torch.tensor(target_test, dtype=torch.float32)
target_val_tensor = torch.tensor(target_val, dtype=torch.float32)
# Visualize the shape
print(f'data_train_tensor shape: {data_train_tensor.shape}, type: {data_train_tensor.dtype}')
print(f'data_test_tensor shape: {data_test_tensor.shape}, type: {data_test_tensor.dtype}')
print(f'data_val_tensor shape: {data_val_tensor.shape}, type: {data_val_tensor.dtype}\n')
print(f'target_train_tensor shape: {target_train_tensor.shape}, type: {target_train_tensor.dtype}')
print(f'target_test_tensor shape: {target_test_tensor.shape}, type: {target_test_tensor.dtype}')
print(f'target_val_tensor shape: {target_val_tensor.shape}, type: {target_val_tensor.dtype}')
data_train_tensor shape: torch.Size([2082, 6767]), type: torch.float32
data_test_tensor shape: torch.Size([695, 6767]), type: torch.float32
data_val_tensor shape: torch.Size([694, 6767]), type: torch.float32
target_train_tensor shape: torch.Size([2082, 1]), type: torch.float32
target_test_tensor shape: torch.Size([695, 1]), type: torch.float32
target_val_tensor shape: torch.Size([694, 1]), type: torch.float32
Next we define a helper function that extracts the age data and the gait data from the combined data, since the two have to be treated separately in the model.
def get_data(combined_data):
    """
    Helper function to split the combined data into age data and gait data.
    """
    combined_data = combined_data.reshape(len(combined_data), 101, -1)
    # split into gait data (all joint/axis columns) and age data (last column)
    original_gait_data = combined_data[:, :, :-1]  # shape (batch_size, 101, 66)
    original_age_data = combined_data[:, 0, -1]    # shape (batch_size,)
    # reshape age data to (batch_size, 1)
    original_age_data = original_age_data.unsqueeze(1)  # add a feature dimension
    return original_age_data, original_gait_data
Reproducible Experiment¶
To make the experiments with the deep learning architecture reproducible, we set torch.manual_seed to 1.
# randomstate for torch
torch.manual_seed(1)
<torch._C.Generator at 0x148a7188130>
Deep Learning Architecture¶
The architecture combines an LSTM model for sequential gait data with a fully connected layer for non-sequential age data.
Components of the Architecture¶
- LSTM for Gait Data:
Input size: 66 features representing the joint data. Hidden size: 128, defining the dimensionality of the LSTM outputs. Number of layers: 2 stacked LSTM layers. Batch first: set to True so the model works with data in the format (batch_size, seq_length, input_size).
- Linear Layer for Age Data:
In Features: 1, as age data has a single column. Out Features: 16, capturing age-related patterns through 16 distinct nodes.
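As a small sanity check (not part of the original notebook), the tensor shapes produced by this LSTM configuration can be sketched as follows; the batch size of 4 is arbitrary:

```python
import torch
import torch.nn as nn

# Same configuration as in the classifier below
lstm = nn.LSTM(input_size=66, hidden_size=128, num_layers=2, batch_first=True)
seq = torch.randn(4, 101, 66)  # (batch_size, seq_length, input_size)
out, (h_n, c_n) = lstm(seq)

print(out.shape)  # torch.Size([4, 101, 128]): one 128-dim output per time step
print(h_n.shape)  # torch.Size([2, 4, 128]): final hidden state of each layer
```

Taking `out[:, -1, :]` therefore yields one 128-dimensional summary vector per gait cycle, which is what gets concatenated with the age features.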
Forward Pass:¶
Concatenating the LSTM and linear layer outputs:
To integrate gait and age data, we concatenate the final output from the LSTM with the output from the age data’s linear layer. This combined data then flows through three fully connected layers, each progressively reducing in size, ending with a single output neuron. This final layer output suits binary classification with BCEWithLogitsLoss, which expects logits as input for numerical stability.
Activation and Regularization: Each linear layer uses Leaky ReLU as the activation function, mitigating issues like the dying ReLU problem. Dropout is applied after each layer to prevent overfitting by randomly zeroing some neurons.
Prediction:¶
The model's prediction function applies the sigmoid activation to logits from the forward pass, producing probabilities between 0 and 1, which represent the model’s confidence in each class (e.g., male or female).
class GaitGenderClassifier(nn.Module):
    def __init__(self, dropout_prob):
        super(GaitGenderClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size=66, hidden_size=128, num_layers=2, batch_first=True)
        self.age_fc = nn.Sequential(
            nn.Linear(1, 16),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_prob)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(128 + 16, 64),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_prob)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_prob)
        )
        self.fc3 = nn.Sequential(
            nn.Linear(32, 8),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_prob)
        )
        self.output = nn.Linear(8, 1)

    def forward(self, age, gait_cycle):
        batch_size = gait_cycle.shape[0]
        h_0 = torch.zeros(2, batch_size, 128).to(gait_cycle.device)
        c_0 = torch.zeros(2, batch_size, 128).to(gait_cycle.device)
        x, _ = self.lstm(gait_cycle, (h_0, c_0))
        x = x[:, -1, :]  # take the output from the last time step
        age_out = self.age_fc(age)
        combined = torch.cat((x, age_out), dim=1)
        x = self.fc1(combined)
        x = self.fc2(x)
        x = self.fc3(x)
        output = self.output(x)
        return output

    def predict(self, age, gait_cycle, threshold=0.5):
        self.eval()  # set model to evaluation mode
        with torch.no_grad():
            logits = self.forward(age, gait_cycle)              # raw logits
            probabilities = torch.sigmoid(logits)               # sigmoid -> probabilities
            predictions = (probabilities >= threshold).float()  # threshold for binary classification
        return predictions
Training:¶
During training, the model is evaluated based on accuracy and validation loss.
Device Management:¶
The model and data are moved to a GPU if available to accelerate computations.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

X_train, y_train, X_val, y_val = X_train.to(device), y_train.to(device), X_val.to(device), y_val.to(device)
model.to(device)
Initialisation:¶
We initialize metrics to evaluate model performance, setting the evaluation metric to -∞ to prioritize improvement in accuracy. The best validation loss is set to +∞ since lower values indicate better performance. We save the initial model configuration as our baseline. Since early stopping criteria are implemented, patience counters are also initialized to 0.
best_eval_metric_val = -float('inf')
best_val_loss = float('inf')
best_model = copy.deepcopy(model.state_dict())
patience_counter_loss = 0
patience_counter_metric = 0
Data Loader:¶
Next, we create a DataLoader for the training dataset, enabling mini-batch processing. Setting shuffle=True ensures that data is shuffled at the start of each epoch to improve generalization.
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
To stabilize training, we define a scheduler that reduces the learning rate by a factor of 0.5 if validation loss fails to improve after a set number of epochs (patience).
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
Training Loop:¶
In the training loop, we begin by setting the model to training mode. We then iterate over the mini-batches, moving each batch of data and target tensors to the specified device. Gradients are cleared at the start of each batch to ensure that only the current batch's gradients affect the parameter update, aligning optimization with the loss for that specific batch.
A forward pass through the model computes the output, and the loss is calculated using BCEWithLogitsLoss, a function that combines Binary Cross-Entropy (BCE) loss and the sigmoid activation. This criterion is mathematically defined as:
$\mathrm{BCE} = -\left(y \cdot \log(p) + (1-y) \cdot \log(1-p)\right)$
where $y$ is the true label (0 for male, 1 for female) and $p$ is the predicted probability. BCEWithLogitsLoss expects logits (raw outputs) and internally applies a sigmoid to transform these into probabilities.
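The equivalence between the fused criterion and the formula above can be checked numerically (a small sketch with made-up logits and targets):

```python
import torch
import torch.nn as nn

logits = torch.tensor([[0.8], [-1.2]])   # made-up raw model outputs
targets = torch.tensor([[1.0], [0.0]])   # 1 = female, 0 = male

# Fused criterion: sigmoid + BCE in one numerically stable step
loss_fused = nn.BCEWithLogitsLoss()(logits, targets)

# Manual two-step version of the formula above (mean over the batch)
p = torch.sigmoid(logits)
loss_manual = -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()

print(torch.allclose(loss_fused, loss_manual))  # True
```

The fused version avoids computing `log(sigmoid(x))` explicitly, which would overflow for large negative logits.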
Backpropagation is then used to compute gradients, after which the optimizer updates the model weights. Finally, the loss for each epoch is accumulated.
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(*get_data(data))
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
Validation:¶
After training, the model is evaluated on the validation set. First, we switch the model to evaluation mode and use torch.no_grad() to disable gradient calculations, ensuring faster computation and reduced memory usage. We then compute the validation outputs and loss. The learning rate scheduler updates based on the validation loss. If the loss does not improve for 3 consecutive epochs, the learning rate is reduced by a factor of 0.5.
model.eval()
with torch.no_grad():
    val_outputs = model(*get_data(X_val))
    loss_val = criterion(val_outputs, y_val)
scheduler.step(loss_val)
Next, we calculate the accuracies for both the training and the validation data.
predictions_train = model.predict(*get_data(X_train))
accuracy_train = accuracy_score(y_train.cpu(), predictions_train.cpu())
predictions_val = model.predict(*get_data(X_val))
accuracy_val = accuracy_score(y_val.cpu(), predictions_val.cpu())
The accuracies and losses are then logged, both for choosing the dropout rate and for visualizing the training process in TensorBoard.
if log_dropout:
    train_losses.append(epoch_loss / len(train_loader))
    val_losses.append(loss_val.item())
    train_accuracies.append(accuracy_train)
    val_accuracies.append(accuracy_val)

if log_tensorboard:
    writer.add_scalars("Loss", {"train": epoch_loss / len(train_loader), "val": loss_val}, epoch)
    writer.add_scalars("Accuracy", {"train": accuracy_train, "val": accuracy_val}, epoch)
    writer.flush()
Early Stopping¶
If neither the validation loss nor the validation accuracy improves for patience epochs, training is halted; the weights from the epoch with the highest validation accuracy are kept in best_model.
if loss_val < best_val_loss:
    best_val_loss = loss_val
    patience_counter_loss = 0
else:
    patience_counter_loss += 1

if accuracy_val > best_eval_metric_val:
    best_eval_metric_val = accuracy_val
    best_model = copy.deepcopy(model.state_dict())
    patience_counter_metric = 0
else:
    patience_counter_metric += 1

if (patience_counter_loss >= patience) or (patience_counter_metric >= patience):
    print(f"Early stopping at epoch {epoch+1}")
    break
dropout_results = [] # storing the results of the dropout experiment
def training(model, X_train, y_train, X_val, y_val, criterion, optimizer, epochs, writer,
             batch_size=64, patience=50, dropout_prob=0.5,
             log_tensorboard=True, log_dropout=True, verbose=True):
    # Set device to GPU if available
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Move data and model to device
    X_train, y_train, X_val, y_val = X_train.to(device), y_train.to(device), X_val.to(device), y_val.to(device)
    model.to(device)

    # Initialize best metrics and model
    best_eval_metric_val = -float('inf')          # best validation accuracy so far
    best_val_loss = float('inf')                  # best validation loss so far
    best_model = copy.deepcopy(model.state_dict())
    patience_counter_loss = 0                     # early stopping counter (loss)
    patience_counter_metric = 0                   # early stopping counter (accuracy)

    # DataLoader for batching
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Scheduler for learning rate adjustment (if val_loss does not improve for 3 epochs)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    train_losses, val_losses = [], []  # to visualize the dropout experiments
    train_accuracies, val_accuracies = [], []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        # Training loop
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()               # clear gradients
            outputs = model(*get_data(data))    # forward pass
            loss = criterion(outputs, target)   # calculate loss
            loss.backward()                     # backpropagation
            optimizer.step()                    # update weights
            epoch_loss += loss.item()           # accumulate batch loss

        # Validation
        model.eval()
        with torch.no_grad():
            val_outputs = model(*get_data(X_val))
            loss_val = criterion(val_outputs, y_val)
        scheduler.step(loss_val)  # adjust learning rate based on validation loss

        # Calculate training and validation accuracy
        predictions_train = model.predict(*get_data(X_train))
        accuracy_train = accuracy_score(y_train.cpu(), predictions_train.cpu())
        predictions_val = model.predict(*get_data(X_val))
        accuracy_val = accuracy_score(y_val.cpu(), predictions_val.cpu())

        # Log for the dropout experiment
        if log_dropout:
            train_losses.append(epoch_loss / len(train_loader))
            val_losses.append(loss_val.item())
            train_accuracies.append(accuracy_train)
            val_accuracies.append(accuracy_val)

        # Log to TensorBoard
        if log_tensorboard:
            writer.add_scalars("Loss", {"train": epoch_loss / len(train_loader), "val": loss_val}, epoch)
            writer.add_scalars("Accuracy", {"train": accuracy_train, "val": accuracy_val}, epoch)
            writer.flush()

        # Early stopping counter based on validation loss
        if loss_val < best_val_loss:
            best_val_loss = loss_val
            patience_counter_loss = 0
        else:
            patience_counter_loss += 1

        # Early stopping counter based on validation accuracy; keep the best model
        if accuracy_val > best_eval_metric_val:
            best_eval_metric_val = accuracy_val
            best_model = copy.deepcopy(model.state_dict())
            patience_counter_metric = 0
        else:
            patience_counter_metric += 1

        # Early stopping check
        if (patience_counter_loss >= patience) or (patience_counter_metric >= patience):
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Print status
        if verbose and epoch % 19 == 0:
            print(f"| Epoch {epoch+1} | Train Loss: {epoch_loss / len(train_loader):.4f}, Validation Loss: {loss_val:.4f} | Train Accuracy: {accuracy_train:.4f}, Val Accuracy: {accuracy_val:.4f} |")

    dropout_results.append({'dropout_prob': dropout_prob,
                            'train_loss': train_losses,
                            'val_loss': val_losses,
                            'training_acc': train_accuracies,
                            'validation_acc': val_accuracies})

    # Load the best model
    model.load_state_dict(best_model)
    print(f"Best validation accuracy: {best_eval_metric_val:.4f}")
    writer.close()
    return model
Finding Dropout Probabilities¶
epochs = 10000
learning_rate = 0.0001  # highest val_score at learning_rate 0.0001
dropout_probs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

for dropout_prob in dropout_probs:
    model = GaitGenderClassifier(dropout_prob)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()  # for binary classification without sigmoid in the final layer
    writer = SummaryWriter()

    # Training the model
    trained_model = training(model=model,
                             X_train=data_train_tensor,
                             y_train=target_train_tensor,
                             X_val=data_val_tensor,
                             y_val=target_val_tensor,
                             criterion=criterion,
                             optimizer=optimizer,
                             epochs=epochs,
                             writer=writer,
                             dropout_prob=dropout_prob,
                             log_tensorboard=False,
                             log_dropout=True,
                             verbose=True)
| Epoch 1 | Train Loss: 0.6912, Validation Loss: 0.6820 | Train Accuracy: 0.6206, Val Accuracy: 0.6398 |
| Epoch 20 | Train Loss: 0.1779, Validation Loss: 0.2766 | Train Accuracy: 0.9673, Val Accuracy: 0.9092 |
| Epoch 39 | Train Loss: 0.0437, Validation Loss: 0.1847 | Train Accuracy: 0.9952, Val Accuracy: 0.9539 |
| Epoch 58 | Train Loss: 0.0169, Validation Loss: 0.2077 | Train Accuracy: 0.9986, Val Accuracy: 0.9539 |
| Epoch 77 | Train Loss: 0.0200, Validation Loss: 0.2091 | Train Accuracy: 0.9986, Val Accuracy: 0.9553 |
Early stopping at epoch 89
Best validation accuracy: 0.9597
| Epoch 1 | Train Loss: 0.7115, Validation Loss: 0.6877 | Train Accuracy: 0.5173, Val Accuracy: 0.5173 |
| Epoch 20 | Train Loss: 0.3020, Validation Loss: 0.2905 | Train Accuracy: 0.9270, Val Accuracy: 0.8890 |
| Epoch 39 | Train Loss: 0.0353, Validation Loss: 0.1808 | Train Accuracy: 0.9976, Val Accuracy: 0.9640 |
| Epoch 58 | Train Loss: 0.0306, Validation Loss: 0.1657 | Train Accuracy: 0.9976, Val Accuracy: 0.9697 |
| Epoch 77 | Train Loss: 0.0317, Validation Loss: 0.1675 | Train Accuracy: 0.9976, Val Accuracy: 0.9697 |
Early stopping at epoch 82
Best validation accuracy: 0.9726
| Epoch 1 | Train Loss: 0.7351, Validation Loss: 0.6950 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
| Epoch 20 | Train Loss: 0.4086, Validation Loss: 0.3200 | Train Accuracy: 0.9092, Val Accuracy: 0.8775 |
| Epoch 39 | Train Loss: 0.0848, Validation Loss: 0.2648 | Train Accuracy: 0.9837, Val Accuracy: 0.9380 |
| Epoch 58 | Train Loss: 0.0352, Validation Loss: 0.2207 | Train Accuracy: 0.9976, Val Accuracy: 0.9524 |
| Epoch 77 | Train Loss: 0.0333, Validation Loss: 0.2316 | Train Accuracy: 0.9990, Val Accuracy: 0.9524 |
Early stopping at epoch 87
Best validation accuracy: 0.9568
| Epoch 1 | Train Loss: 0.7926, Validation Loss: 0.7107 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
| Epoch 20 | Train Loss: 0.6237, Validation Loss: 0.5537 | Train Accuracy: 0.8650, Val Accuracy: 0.8487 |
| Epoch 39 | Train Loss: 0.2363, Validation Loss: 0.1486 | Train Accuracy: 0.9731, Val Accuracy: 0.9568 |
| Epoch 58 | Train Loss: 0.0527, Validation Loss: 0.1223 | Train Accuracy: 0.9995, Val Accuracy: 0.9798 |
| Epoch 77 | Train Loss: 0.0501, Validation Loss: 0.1300 | Train Accuracy: 1.0000, Val Accuracy: 0.9827 |
| Epoch 96 | Train Loss: 0.0457, Validation Loss: 0.1314 | Train Accuracy: 1.0000, Val Accuracy: 0.9827 |
Early stopping at epoch 102
Best validation accuracy: 0.9827
| Epoch 1 | Train Loss: 0.7440, Validation Loss: 0.6920 | Train Accuracy: 0.5173, Val Accuracy: 0.5173 |
| Epoch 20 | Train Loss: 0.6976, Validation Loss: 0.6998 | Train Accuracy: 0.5173, Val Accuracy: 0.5173 |
| Epoch 39 | Train Loss: 0.6938, Validation Loss: 0.6996 | Train Accuracy: 0.5173, Val Accuracy: 0.5173 |
Early stopping at epoch 51
Best validation accuracy: 0.5173
| Epoch 1 | Train Loss: 0.9118, Validation Loss: 0.7279 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
| Epoch 20 | Train Loss: 0.7098, Validation Loss: 0.6961 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
| Epoch 39 | Train Loss: 0.6993, Validation Loss: 0.6947 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
Early stopping at epoch 51
Best validation accuracy: 0.4827
| Epoch 1 | Train Loss: 1.1165, Validation Loss: 0.7016 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
| Epoch 20 | Train Loss: 0.7283, Validation Loss: 0.6966 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
| Epoch 39 | Train Loss: 0.7046, Validation Loss: 0.6953 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
Early stopping at epoch 51
Best validation accuracy: 0.4827
| Epoch 1 | Train Loss: 1.8160, Validation Loss: 0.6979 | Train Accuracy: 0.5173, Val Accuracy: 0.5173 |
| Epoch 20 | Train Loss: 1.0527, Validation Loss: 0.6950 | Train Accuracy: 0.5173, Val Accuracy: 0.5173 |
| Epoch 39 | Train Loss: 0.8591, Validation Loss: 0.6945 | Train Accuracy: 0.5173, Val Accuracy: 0.5173 |
Early stopping at epoch 51
Best validation accuracy: 0.5173
Visualizing the dropout probabilities suggests that a probability of 0.4 yields the highest validation accuracy while minimizing overfitting.
- p=0.1, 0.2: overfitting begins at about 82%
- p=0.3, 0.4: overfitting begins at about 92%
- p=0.5–0.8: no useful results; the model fails to learn
fig, axis = plt.subplots(1, 2, figsize=(15, 10))
# Define the color map
colors = cm.RdYlGn(np.linspace(0, 1, len(dropout_results)))
# Plot the results with the RdYlGn color map
for idx, result in enumerate(dropout_results):
    axis[0].plot(result['train_loss'], label=f"train drop={result['dropout_prob']}", linestyle='--', color=colors[idx])
    axis[0].plot(result['val_loss'], label=f"val drop={result['dropout_prob']}", color=colors[idx])
    axis[1].plot(result['validation_acc'], label=f"val drop={result['dropout_prob']}", color=colors[idx])
    axis[1].plot(result['training_acc'], label=f"train drop={result['dropout_prob']}", linestyle='--', color=colors[idx])
# Configure axis 0 (Loss plot)
axis[0].set_xlabel("Epoch")
axis[0].set_ylabel("Loss")
axis[0].set_title("Loss vs Dropout Probability")
axis[0].legend()
# Configure axis 1 (Accuracy plot)
axis[1].set_xlabel("Epoch")
axis[1].set_ylabel("Accuracy")
axis[1].set_title("Accuracy vs Dropout Probability")
axis[1].legend()
plt.show()
Final Model - Tracking with Tensorboard¶
# %load_ext tensorboard
# %tensorboard --logdir runs
#%tensorboard --logdir=runs/ --host localhost --port 8088 # http://localhost:8088
dropout_prob = 0.4
# Instantiate model, criterion, optimizer, and TensorBoard writer
model = GaitGenderClassifier(dropout_prob)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss() # For binary classification without sigmoid in final layer
writer = SummaryWriter()
# Training the model
trained_model = training(model=model,
                         X_train=data_train_tensor,
                         y_train=target_train_tensor,
                         X_val=data_val_tensor,
                         y_val=target_val_tensor,
                         criterion=criterion,
                         optimizer=optimizer,
                         epochs=epochs,
                         writer=writer,
                         dropout_prob=dropout_prob,
                         log_tensorboard=True,
                         log_dropout=False,
                         verbose=True)
| Epoch 1 | Train Loss: 0.7285, Validation Loss: 0.6999 | Train Accuracy: 0.4827, Val Accuracy: 0.4827 |
| Epoch 20 | Train Loss: 0.6001, Validation Loss: 0.5522 | Train Accuracy: 0.8622, Val Accuracy: 0.8401 |
| Epoch 39 | Train Loss: 0.2803, Validation Loss: 0.1905 | Train Accuracy: 0.9616, Val Accuracy: 0.9452 |
| Epoch 58 | Train Loss: 0.1748, Validation Loss: 0.1707 | Train Accuracy: 0.9798, Val Accuracy: 0.9539 |
| Epoch 77 | Train Loss: 0.1394, Validation Loss: 0.1614 | Train Accuracy: 0.9870, Val Accuracy: 0.9582 |
| Epoch 96 | Train Loss: 0.1163, Validation Loss: 0.1624 | Train Accuracy: 0.9870, Val Accuracy: 0.9582 |
Early stopping at epoch 111
Best validation accuracy: 0.9597
from PIL import Image
def plot_image(name):
    img = Image.open(name)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

try:
    plot_image("/content/drive/MyDrive/HochschuleKiel/DeepLearning/Portfolio Exam/code/Tensorbaord - Loss.png")
except:
    plot_image("Tensorbaord - Loss.png")

try:
    plot_image("/content/drive/MyDrive/HochschuleKiel/DeepLearning/Portfolio Exam/code/Tensorboard - Accuracy.png")
except:
    plot_image("Tensorboard - Accuracy.png")
Evaluate the final Model¶
After training, we move all tensors and the model back to the CPU.
# bring tensor back to cpu
data_train_tensor = data_train_tensor.to('cpu')
data_test_tensor = data_test_tensor.to('cpu')
data_val_tensor = data_val_tensor.to('cpu')
target_train_tensor = target_train_tensor.to('cpu')
target_test_tensor = target_test_tensor.to('cpu')
target_val_tensor = target_val_tensor.to('cpu')
# also the model
model.to('cpu');
Next, we compare the performance of the deep learning model with the SVM and the baseline.
results = add_results(results, "Deep Learning",
                      target_train_tensor,
                      model.predict(*get_data(data_train_tensor)),
                      target_val_tensor,
                      model.predict(*get_data(data_val_tensor)),
                      target_test_tensor,
                      model.predict(*get_data(data_test_tensor)))
display(results)
|   | algorithm | dataset | Accuracy | AUC |
|---|---|---|---|---|
| 0 | Baseline | train | 51.729 | 50.000 |
| 1 | Baseline | val | 51.729 | 50.000 |
| 2 | Baseline | test | 51.799 | 50.000 |
| 3 | SVM | train | 92.988 | 92.886 |
| 4 | SVM | val | 92.795 | 92.617 |
| 5 | SVM | test | 93.381 | 93.300 |
| 6 | Deep Learning | train | 98.223 | 98.216 |
| 7 | Deep Learning | val | 95.965 | 95.921 |
| 8 | Deep Learning | test | 95.971 | 95.956 |
The Deep Learning model shows very high training accuracy and AUC, suggesting it effectively distinguishes between classes, though it may be slightly overfitted. On validation and test sets, its accuracy and AUC drop slightly, showing strong generalization. It outperforms the SVM across all datasets, with a 3–4% advantage on the test set, indicating a better ability to capture complex patterns.
Conclusions and Future Work¶
Conclusion¶
This project successfully demonstrates that gender classification based on gait cycle data is feasible, even with a relatively small sample size, and shows promising results in distinguishing between male and female subjects based on their individual gait cycles. The SVM model achieved strong performance, with around 93% accuracy and AUC on test data, while the Deep Learning model went further, achieving over 98% on training accuracy and around 96% accuracy and AUC on validation and test sets. These findings highlight the potential of deep learning in capturing complex, gender-related patterns in gait data with high reliability.
Potential Limitations:
- Sample Size and Diversity: With a limited number of subjects, there is a risk that the model has learned subject-specific patterns rather than generalized gender differences. A larger and more diverse dataset would improve generalization, focusing more on gender-related patterns rather than individual gait idiosyncrasies.
- Model Complexity: In this proof of concept, a simple neural network was used. Expanding data and resources will be crucial to build more sophisticated models with improved accuracy and reduced overfitting.
Future Work¶
The promising results of this study suggest several directions for further research:
Data Collection: Expanding the dataset to include more subjects with varying ages and body types will enhance the model’s robustness and generalizability. A potential approach to improve generalization could be to split the training, validation, and test data based on unique subjects. For example, we could assign 50% of the subjects to the training set, 25% to the validation set, and 25% to the test set. To ensure balanced representation, it’s essential to have a roughly equal number of male and female subjects in each split.
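A hypothetical sketch of such a subject-based split (the subject IDs and group sizes below are made up; scikit-learn's GroupShuffleSplit guarantees each subject lands in exactly one split):

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical subject IDs: one entry per gait cycle, repeated per subject
rng = np.random.default_rng(1)
subject_ids = np.repeat(np.arange(90), 38)  # ~90 subjects, several cycles each
X = rng.normal(size=(len(subject_ids), 10))  # placeholder feature matrix

# First split off 50% of the subjects for training
gss = GroupShuffleSplit(n_splits=1, train_size=0.5, random_state=1)
train_idx, rest_idx = next(gss.split(X, groups=subject_ids))

# Split the remaining subjects 50/50 into validation and test
gss2 = GroupShuffleSplit(n_splits=1, train_size=0.5, random_state=1)
val_rel, test_rel = next(gss2.split(X[rest_idx], groups=subject_ids[rest_idx]))
val_idx, test_idx = rest_idx[val_rel], rest_idx[test_rel]

# No subject appears in more than one split
assert not set(subject_ids[train_idx]) & set(subject_ids[val_idx])
assert not set(subject_ids[val_idx]) & set(subject_ids[test_idx])
```

A stratified variant such as StratifiedGroupKFold could additionally balance the number of male and female subjects per split.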
Model Optimization: To address overfitting and improve performance, we could refine the deep learning model by exploring other architectures, such as GRU-based RNNs, which are efficient and less prone to overfitting, especially with smaller datasets. Alternatively, Conv1D-based CNNs may be effective for handling data with (N, features, gait cycles) shapes. Using hyperparameter optimization tools like Optuna could also help identify the best configurations, and stabilizing the model with normalization layers could be evaluated. Since the SVM achieved over 90% accuracy without fine-tuning, applying hyperparameter optimization could push its accuracy even higher, which suggests that deep learning might not be necessary to achieve excellent performance. Feature reduction, e.g. with PCA, should also be explored: since our features are correlated, this could substantially increase efficiency and reduce computing time.
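The PCA idea could be sketched as follows; random placeholder data stands in for the real feature matrix here:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 6767))  # placeholder for the flattened gait features

# Keep the smallest number of components that explains 95% of the variance
pca = PCA(n_components=0.95)
X_reduced = pca.fit_transform(X)

print(X_reduced.shape)  # far fewer columns than the original 6767
```

On the real, correlated gait features the reduction would be much stronger than on this uncorrelated placeholder data.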
Feature Exploration: Including features like speed and acceleration of each joint could uncover additional gait patterns, adding depth to the model's understanding of gender differences.
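As a rough sketch of how such features could be derived from the existing angle trajectories (finite differences over the normalized gait cycle; the array shapes here are assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
angles = rng.normal(size=(101, 66))  # one gait cycle: 101 time points x 66 joint angles

# Finite-difference approximations of speed and acceleration along the time axis
velocity = np.gradient(angles, axis=0)
acceleration = np.gradient(velocity, axis=0)

print(velocity.shape, acceleration.shape)  # (101, 66) (101, 66)
```

These derived channels could be stacked alongside the raw angles, tripling the per-time-step feature count fed into the model.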
Clinical Applications: Testing the model in clinical trials could validate its potential for personalized assessments in therapeutic settings based on gait characteristics.
Future Model Applications: Developing models that can classify individual subjects or detect injuries could open pathways for various applications beyond gender classification.
This study establishes a valuable foundation for future work in gait analysis and gender classification, with potential applications in clinical diagnostics, sports science, and security, highlighting the broader possibilities of gait-based assessment and monitoring.